import sys
import cv2
import gym
import highway_env
from gym import spaces
#import stable_baselines
#from sb3_contrib import TRPO
#sys.path.insert(1, "C:/Users/cahit/Documents/GitHub/drlcarsim/learntodrive/stable_baselines")
#import stable_baselines_tf2
import stable_baselines3
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from highway_env.envs.common.observation import TimeToCollisionObservation, KinematicObservation
from sb3_contrib import TRPO
import torch as th
import torch.nn as nn
import numpy as np
import tensorflow
import tensorboard
from highway_env.vehicle.kinematics import Performance, Logger
#torch.set_num_threads(16)
# Id of the gym environment to build; it is also embedded in the name of the
# recorded video file.
situation = "racetrack-v2"
# situation = 'highway-v0'

# Frame size handed to the video writer as (width, height).
frameSize = (1280, 560)

# Recorder for the evaluation runs: 'mp4v' codec at 4 frames per second.
# Alternative encoder settings that were used before:
# out = cv2.VideoWriter('video'+situation+'.avi',cv2.VideoWriter_fourcc(*'DIVX'), 16, frameSize)
out = cv2.VideoWriter(
    f'video{situation}.avi',
    cv2.VideoWriter_fourcc(*'mp4v'),
    4,
    frameSize,
)

import stable_baselines3.common.policies
# Build the environment selected by `situation` and override parts of its
# configuration before the first reset.
env = gym.make(situation)

# NOTE(review): in upstream highway-env, 'lateral' / 'longitudinal' are read
# from the nested 'action' dict and 'absolute' / 'normalize' /
# 'vehicles_count' from the nested 'observation' dict. Confirm that this
# environment actually reads them at the top level as written here.
env_settings = {
    "offroad_terminal": True,
    "screen_width": 1280,
    "screen_height": 560,
    "renderfps": 60,
    "action": {"type": "ContinuousAction"},
    "lateral": True,
    "longitudinal": True,
    "other_vehicles": 1,  # non-ego vehicles
    "vehicles_count": 1,
    "show_trajectories": False,
    "absolute": False,
    "normalize": False,  # False
    "observation": {"type": "Kinematics"},
    # Alternative observation that can be swapped in:
    # "observation": {
    #     "type": "OccupancyGrid",
    #     "vehicles_count": 1,
    #     "features": ["presence", "x", "y", "vx", "vy", "cos_h", "sin_h"],
    #     "features_range": {
    #         "x": [-100, 100],
    #         "y": [-100, 100],
    #         "vx": [-20, 20],
    #         "vy": [-20, 20]
    #     },
    # }
}
env.configure(env_settings)

# Open training notes:
# - make the loss diverge / go down
# - regularization may fix the wobbliness
# - the newest reward sucks; make a better one if regularization doesn't work
env.reset()

# Tuning constants; neither is referenced by the rest of the visible script
# (the TRPO constructor call spells out its own batch size).
n_cpu = 16
batch_size = 64

class Customcnn(BaseFeaturesExtractor):
    """
    Convolutional feature extractor: two convolution + ReLU stages, a
    flatten, then one linear + ReLU projection to ``features_dim`` values.

    :param observation_space: (gym.Space) Observation space; its first
        dimension is used as the number of input channels.
    :param features_dim: (int) Number of features extracted.
        This is the output width of the final linear layer.
    """

    def __init__(self, observation_space: spaces.Box, features_dim: int = 32):
        super().__init__(observation_space, features_dim)

        # Observations are taken to be CxHxW images (channels first); any
        # re-ordering has to be done by pre-processing or a wrapper.
        in_channels = observation_space.shape[0]
        conv_stack = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=1, stride=2, padding=0),
            nn.ReLU(),
            nn.Flatten(),
            # nn.LSTM(n_input_channels, 16, 3, bias = False),
            # nn.ReLU(),
            # nn.Flatten()
        ]
        self.cnn = nn.Sequential(*conv_stack)

        # The flattened width depends on the observation shape, so measure it
        # by pushing one sampled observation (with a leading batch axis)
        # through the convolution stack.
        probe = th.as_tensor(observation_space.sample()[None]).float()
        with th.no_grad():
            n_flatten = self.cnn(probe).shape[1]

        projection = [nn.Linear(n_flatten, features_dim), nn.ReLU()]
        self.linear = nn.Sequential(*projection)

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Run the convolution stack, then the linear projection."""
        conv_features = self.cnn(observations)
        return self.linear(conv_features)


# Policy options that plug the custom CNN feature extractor into the policy.
# They are not the active choice right now; enable them through the
# commented-out assignment below.
a = {
    "features_extractor_class": Customcnn,
    "features_extractor_kwargs": {"features_dim": 32},
}

# policy_kwargs = a
# Active policy options: only the network architecture is overridden.
policy_kwargs = {"net_arch": [64, 64]}

# Hyper-parameters of the TRPO agent, collected in one mapping and expanded
# into the constructor call below.
trpo_settings = dict(
    learning_rate=0.005,  # 0.006
    n_steps=1000,
    batch_size=64,  # 128
    gamma=0.99,
    cg_max_steps=15,
    cg_damping=0.1,
    line_search_shrinking_factor=0.8,
    line_search_max_iter=10,
    n_critic_updates=10,
    gae_lambda=0.95,
    use_sde=False,
    sde_sample_freq=-1,
    normalize_advantage=True,
    target_kl=0.02,  # look in to this
    sub_sampling_factor=1,
    tensorboard_log="racetrack_trpo/",
    policy_kwargs=policy_kwargs,
    verbose=1,
    seed=None,
    device='cuda',
    _init_setup_model=True,
)

# MLP policy acting on the environment configured earlier in the script.
model = TRPO("MlpPolicy", env, **trpo_settings)




# Location of the saved model. It is derived from the `situation` variable
# (not the literal text 'situation') so that every environment id gets its
# own model directory, and the same path is used for saving and loading.
model_path = situation + '_trpo/model'

# Train a new model from scratch and store it.
model.learn(int(1.5e5), progress_bar=True)
model.save(model_path)

# To continue training a previously saved model instead, comment out the two
# lines above and enable the lines below:
# model = TRPO.load(model_path)
# model.set_env(env)
# model.learn(int(1e5), progress_bar=True)
# model.save(model_path)

print()
print("Done Learning!!")
print()


########## Load and test saved model##############
# Reload the stored model so that the evaluation below runs on exactly what
# was written to disk.
model = TRPO.load(model_path)
#while True:


# Evaluate the loaded policy over a fixed number of episodes. Every step is
# logged through `lolly`, rendered, and appended to the video writer `out`;
# after each episode the log is handed to `perfm` and cleared.
perfm = Performance()
lolly = Logger()

number_of_runs = 100
try:
    for run in range(number_of_runs):
        done = truncated = False
        obs, info = env.reset()
        ego_car = env.controlled_vehicles[0]
        reward = 0
        step_counter = 0
        # An episode stops when the environment terminates or truncates it,
        # when the ego vehicle's speed drops to 0.5 or below, or after 800
        # steps. (A `reward > 0` condition was considered as well.)
        while (
            not done
            and not truncated
            and ego_car.speed > 0.5
            and step_counter < 800
        ):
            # Idea noted by the author (not implemented here): compute the
            # time-to-collision (TTC) from the observation, write it to a
            # text file, and read it back in the reward calculation stage.
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, truncated, info = env.step(action)
            step_counter += 1
            lolly.file(ego_car)

            env.render()
            cur_frame = env.render(mode="rgb_array")
            # "rgb_array" frames are RGB while cv2.VideoWriter expects BGR;
            # without the conversion red and blue are swapped in the video.
            out.write(cv2.cvtColor(cur_frame, cv2.COLOR_RGB2BGR))
        print(f'it has been {run+1} runs so far')
        perfm.add_measurement(lolly)
        lolly.clear_log()
finally:
    # Always finalize the video file, even when an episode raises.
    out.release()
perfm.print_performance()
print('DONE')